fix overwrite bug when adding symbol to dictionary #5329
Open: lydianish wants to merge 12 commits into facebookresearch:main from lydianish:main
This bug ignored the tokens that were meant to be overwritten and appended them to the end of the dictionary symbols. For example, a dictionary with 50K tokens that already has `<s>`, `</s>`, `<pad>` and `<unk>` with the `#fairseq:overwrite` tag will end up having 50004 tokens when loaded.
Assert that overwrite works as expected (i.e. ignoring the duplicates)
For backward compatibility with the existing models/pipelines that use a flawed dictionary loaded from file (before the bug fix)
…tionary After fixing the behaviour of add_symbol, two of the unit tests were failing because they called the function with the default value of overwrite (False).
This ensures compatibility with all the calls to `add_symbol` across the repo (which overwrite by default, as in the original implementation). The only place where the value is explicitly changed is when loading the dictionary from file (which was the source of the bug). In a file, you have to explicitly say whether the tokens should be overwritten or duplicated.
Before submitting
What does this PR do?
Fixes #3064.
Fixes #3705.
Fixes #1309.
TL;DR: This PR fixes the bug that duplicates the symbols that were meant to be overwritten when a vocabulary file is loaded. See the detailed explanation in this blog post.
Expected behavior:
A Dictionary object has an `indices` dict and two lists (`symbols` and `counts`). By default, when loading a vocabulary from a file, a Dictionary instance is first created by adding 4 special tokens (`<s>`, `<pad>`, `</s>` and `<unk>`, in that order). Then, all the entries from the file are appended to the Dictionary. If the vocabulary file already has some of the special tokens, their file entry should contain `#fairseq:overwrite`; otherwise a "duplicate" error will be raised at runtime. Furthermore, during preprocessing, the saved dictionary should not contain any of the special symbols.
Current behavior:
The `add_symbol` function is responsible for adding the symbols to the Dictionary. It has an `overwrite` argument that is set to `True` when the corresponding line in the file has `#fairseq:overwrite`. Rather than testing `if word in self.indices and overwrite`, it is currently testing `if word in self.indices and not overwrite`, which makes it ignore the case where the symbol should actually be overwritten. Hence, the symbol is appended to the `symbols` list, and its index is changed in the `indices` dict. This results in duplicate symbols and incorrect indices. Generally, only the special symbols will be affected. However, because the number of special tokens is set during initialization, it remains correct.
For example, a dictionary with 50K tokens that already has `<s>`, `<pad>`, `</s>` and `<unk>` with the `#fairseq:overwrite` tag will end up having 50004 tokens when loaded. This will also propagate to the subsequent model, which will have an embedding dimension of 50004 instead of 50K. Also, with `fairseq-preprocess`, the resulting dictionary will skip the first 4 special symbols but will still contain the duplicate ones.
Domino effects and backward compatibility:
By fixing this bug, dictionary files will be loaded properly. However, this fix might cause problems in pipelines that use existing architectures and pretrained models because of the mismatch in sentencepiece encoding and/or embedding dimension.
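To make the size mismatch concrete, here is a minimal sketch that uses a toy stand-in class rather than fairseq's actual `Dictionary`. The names `ToyDictionary`, `load` and the `fixed` switch are hypothetical and exist only for this illustration. The two conditions are the ones quoted in this description; how an overwritten symbol's count is updated in the fixed branch is an assumption, not something this description specifies.

```python
# Toy stand-in for the dictionary described in this PR; NOT fairseq's class.
SPECIALS = ("<s>", "<pad>", "</s>", "<unk>")
OVERWRITE_FLAG = "#fairseq:overwrite"


class ToyDictionary:
    def __init__(self, fixed):
        # `fixed` is a hypothetical switch between the two conditions
        # quoted in the PR description.
        self.fixed = fixed
        self.symbols, self.counts, self.indices = [], [], {}
        for token in SPECIALS:
            self._append(token, 1)
        self.nspecial = len(self.symbols)

    def _append(self, word, n):
        idx = len(self.symbols)
        self.indices[word] = idx
        self.symbols.append(word)
        self.counts.append(n)
        return idx

    def add_symbol(self, word, n=1, overwrite=False):
        if self.fixed:
            # Condition proposed in the description.
            if word in self.indices and overwrite:
                idx = self.indices[word]
                self.counts[idx] = n  # assumption: the file's count replaces the old one
                return idx
        elif word in self.indices and not overwrite:
            # Condition described as the current one.
            idx = self.indices[word]
            self.counts[idx] += n
            return idx
        # Every other case appends a new entry and repoints the index.
        return self._append(word, n)


def load(lines, fixed):
    """Parse '<token> <count> [#fairseq:overwrite]' lines into a ToyDictionary."""
    d = ToyDictionary(fixed)
    for line in lines:
        word, field = line.rstrip().rsplit(" ", 1)
        overwrite = field == OVERWRITE_FLAG
        if overwrite:
            word, field = word.rsplit(" ", 1)
        if word in d.indices and not overwrite:
            raise RuntimeError(f"Duplicate word found: {word!r}")
        d.add_symbol(word, n=int(field), overwrite=overwrite)
    return d


LINES = [
    "<s> 10 #fairseq:overwrite",
    "<pad> 10 #fairseq:overwrite",
    "</s> 10 #fairseq:overwrite",
    "<unk> 10 #fairseq:overwrite",
    "hello 5",
    "world 3",
]

buggy = load(LINES, fixed=False)
fixed = load(LINES, fixed=True)
print(len(buggy.symbols), buggy.indices["<s>"])  # 10 4
print(len(fixed.symbols), fixed.indices["<s>"])  # 6 0
```

In this toy setup the same six-line file yields 10 symbols under the current condition and 6 under the proposed one, which mirrors the 50004 versus 50K case at a smaller scale. A model whose embedding table was sized from the larger count would no longer match a dictionary loaded with the fix.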
For the sake of backward compatibility, a `#fairseq:duplicate` flag is introduced to ensure that duplicates are kept in the dictionary, just as they were with the bug. When used with `fairseq-preprocess`, the produced dict.txt file will also write `#fairseq:duplicate` next to the same symbols.
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Yes, I did 🙃